#coding=utf-8
# Copyright (c) 2016 Tinydot. inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from ml_feature.util import get_config_type_bytes

def _(func_name, *args):
    '''
    使用此函数调用相应的函数
    :param func_name: 将要调用的函数
    :param args: 配置文件中显示调用的函数参数
    :return: 返回将要调用的函数的返回值
    '''
    temp_list = []
    for i,s in enumerate(args):
        locals()['a'+str(i)] = s
        temp_list.append('a'+str(i))
    op = getattr(_,'op',None)
    get_num = len(args)
    want_num = eval(func_name+'.num')
    callfunc = func_name+'('
    if want_num-get_num==1:
        temp_list.append('op')
    return eval(callfunc + ','.join(temp_list) + ')')


def prod(target_list):
    '''
    计算列表所有数的连乘，target_list里面的元素不能有list和数字并存的情况
    :param target_list: 目标列表，计算此列表的连乘
    :return: 返回连乘结果
    '''
    if type(target_list[0])==type(list()):
        result = 0
    else:
        result = 1
    for temp in target_list:
        if type(temp)==type(list()):
            a = 1
            for i in temp:
                a *= int(i)
            result += a
        else:
            result *= int(temp)
    return result
prod.num=1

def getInputs(size, op):
    '''
    普通op的输入读取函数
    从op的输入里将各个输入的确切样式算出，并加入inputs_list
    :param size: 需要解析的inputs数量
    :param op: 被解析的对象，为TensorFlow的模型中的op对象
    :return: 返回解析完成的列表
    '''
    inputs_list=[]
    for index in range(size):
        inputs_index = op.inputs[index]
        inputs_list.append([int(i) for i in inputs_index.shape])
    return inputs_list
getInputs.num=2

def getInputsConcat(size, op):
    '''
    Concat专用的输入读取函数
    从op的输入里将第一个输入的值和其余输入的样式加入inputs_list
    :param size:需要解析的值和样式数量
    :param op:被解析的对象，此处一定为Concat
    :return: 返回解析完成的列表
    '''
    inputs_list = []
    for index in range(size):
        inputs_index = op.inputs[index]
        if index==0:
            inputs_list.append(inputs_index.eval())
        else:
            inputs_list.append([int(i) for i in inputs_index.shape])
    return inputs_list
getInputsConcat.num=2

def getInputsReshape(size, op):
    '''
    Reshape专用的输入读取函数
    从op的输入里将各个输入的样式加入inputs_list，如果第二个参数有-1，则将-1对应的值算出，并将其替换
    :param size: 需要解析的inputs数量
    :param op: 被解析的对象，此处一定为Reshape
    :return: 返回解析完成的列表
    '''
    inputs_list = []
    for index in range(size):
        inputs_index = op.inputs[index]
        if index == 0:
            temp_shape = [int(i) for i in inputs_index.shape]
            temp_shape_prod = prod(temp_shape)
            inputs_list.append(temp_shape)
        elif index == 1:
            temp_shape = [int(i) for i in inputs_index.eval()]
            if -1 in temp_shape:
                shape_prod = 1
                for shape_index in range(len(temp_shape)):
                    if int(temp_shape[shape_index]) == -1:
                        shape_prod *= 1
                        target_index = shape_index
                    else:
                        shape_prod *= int(temp_shape[shape_index])
                temp_shape[target_index] = int(temp_shape_prod/shape_prod)
                inputs_list.append(temp_shape)
            else:
                inputs_list.append(temp_shape)
    return inputs_list
getInputsReshape.num=2

def getAttrs(attrname_list, op):
    '''
    使用属性名列表，作为key值，获取op中对应的属性值
    :param attrname_list: 存放属性名的列表
    :param op: 被解析的对象，此处为一般的TensorFlow对象
    :return: 返回属性值列表
    '''
    attrs = []
    for key in attrname_list:
        val = op.get_attr(key)
        if type(val) == type(b''):
            attrs.append(val.decode('utf-8'))
        else:
            attrs.append(val)
    return attrs
getAttrs.num=2

def getOutputs(size, op):
    '''
    获取输出的样式，返回列表
    :param size: 单个op获取输出的数量
    :param op: 被获取输出的op对象
    :return: 返回装有输出样式的列表
    '''
    outputs = []
    for index in range(size):
        outputs.append([int(i) for i in op.outputs[index].shape])
    return outputs
getOutputs.num=2

def handleConcat(inputs_list):
    '''
    处理Concat，获取输出样式
    :param inputs_list: Concat的所有输入
    :return: Concat的输出样式
    '''
    concat_dim = int(inputs_list[0])
    outputshape = [int(i) for i in inputs_list[1]]
    switch = 0
    for index in range(1,len(inputs_list)):
        switch += int(inputs_list[index][concat_dim])
    outputshape[concat_dim] = switch
    return outputshape
handleConcat.num=1

def inamount(in_list, op):
    """
    计算输入量，需要考虑输入的op的类型
    :param in_list: 输入，list(list)
    :param op: 需要计算输入量的op对象
    :return: 考虑类型计算出的输入量
    """
    type_bytes_dict = get_config_type_bytes()
    result = 0
    for index,shape in enumerate(in_list):
        inputs_index = op.inputs[index]
        if inputs_index.op.type in ['Placeholder','Const']:
            inputs_index_valtype = inputs_index.op.get_attr('dtype')
        else:
            inputs_index_valtype = inputs_index.op.get_attr('T')
        if inputs_index_valtype in type_bytes_dict:
            power_bytes = type_bytes_dict[inputs_index_valtype]
        else:
            power_bytes = 1
        result += prod(shape) * power_bytes
    return result
inamount.num = 2

def outamount(out_list, op):
    """
    计算输出量，并考虑本op的类型
    :param out_list: 输出，list(list)
    :param op: 计算输出量的op对象
    :return: 考虑类型计算出的输出量
    """
    type_bytes_dict = get_config_type_bytes()
    result = 0
    if op.type in ['Placeholder','Const']:
        outputs_valtype = op.get_attr('dtype')
    else:
        outputs_valtype = op.get_attr('T')
    if outputs_valtype in type_bytes_dict:
        power_bytes = type_bytes_dict[outputs_valtype]
    else:
        power_bytes = 1
    for index,shape in enumerate(out_list):
        result += prod(shape) * power_bytes
    return result
outamount.num = 2